''' 
Combined API server for serving LLM inference requests and FAISS vector database operations.
Loads a language model and provides endpoints for inference and vector similarity search.
This server uses FastAPI to handle requests and responses for both language model inference
and FAISS-based question similarity search.

To run for testing:
CUDA_VISIBLE_DEVICES=0 python /scratch/dhoward/Chatbot/model_server.py
To run this command locally, the hba.conf files must be edited to specify single-GPU usage.

For server running:
sudo supervisorctl start / stop / status myadvisor-chatbot-model-server
'''

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
# Initialize the llm
from chatbot import answer_question_streaming, road_map_evaluation_simple, road_map_evaluation, InferenceRequest, RoadmapEvaluation
import traceback
import torch
import json
import requests
import faissbot
from typing import List, Optional
# Request manager
from request_queue_manager import request_queue, QueueFullError
# FastAPI application exposing both the LLM inference and the FAISS search endpoints.
app = FastAPI(
    title="Combined Model Server",
    description="Serves LLM inference requests and FAISS vector database operations",
)

# Base URL of the external MyAdvisor backend. Endpoint paths are appended
# directly (e.g. f"{api_url}advisor_chatbot_questions"), so keep the trailing slash.
api_url = "https://myadvisor.cs.uct.ac.za/backend/"

# Pydantic request/response models used by the HTTP endpoints.
class SearchRequest(BaseModel):
    """Body of POST /search: a free-text question plus search tuning parameters."""
    question: str  # Query text; /search rejects it with 400 if empty after strip()
    k: int = 3  # Forwarded to faissbot.search_query as k; /search rejects k <= 0
    max_results: int = 2  # Forwarded to faissbot.search_query as max_results
    min_similarity: float = 0.15  # Forwarded to faissbot.search_query; described as the minimum keyword overlap required (exact semantics live in faissbot)

class ReinitializeRequest(BaseModel):
    """Body of POST /reinitialize: the questions to rebuild the FAISS index from."""
    # NOTE(review): the other initialisation paths in this file call
    # faissbot.initialize_database() with (question_id, question) tuples, but
    # this field only accepts plain strings. Confirm what initialize_database
    # expects before relying on /reinitialize.
    questions: List[str]

class SearchResponse(BaseModel):
    """Response of POST /search."""
    indices: List[int]  # Question IDs returned by faissbot.search_query, in the order it returned them
    query: str  # Echo of the submitted question text
    k: int  # Number of IDs actually returned (len(indices)), not the k that was requested

class StatsResponse(BaseModel):
    """Response of GET /stats, built from faissbot.get_database_stats()."""
    # Field names must match the keys of the dict returned by
    # faissbot.get_database_stats(), because /stats does StatsResponse(**stats).
    total_questions: int
    model_loaded: bool
    index_initialized: bool
    dimension: int
    index_size: int

class RoadmapResponse(BaseModel):
    """Response of POST /infer/evaluation."""
    evaluation: str  # Generated evaluation text, or an error description when success is False
    success: bool = True  # False when generation failed; see /infer/evaluation


@app.get("/queue-status")
async def get_queue_status():
    """Report the current state of the shared request queue."""
    status = request_queue.get_queue_status()
    return status

@app.get("/")
async def root():
    """
    Top-level health check.

    Reports LLM/FAISS readiness, the compute device and the request queue
    state, merged with the FAISS database statistics.
    """
    stats = faissbot.get_database_stats()
    ready = faissbot.is_initialized()
    queue_info = request_queue.get_queue_status()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    payload = {
        "message": "Combined Model Server with FAISS Vector Database API is running",
        "llm_status": "ready",
        "faiss_status": "ready" if ready else "not_initialized",
        "device": device,
        "faiss_ready_for_search": ready,
        "queue_status": queue_info,
    }
    # FAISS stats are merged last, so they take precedence on any key collision.
    payload.update(stats)
    return payload

@app.get("/health")
async def health_check():
    """
    Detailed health check endpoint with queue status.

    Failures are reported in the response body ("api_status": "error" plus an
    "error" message) rather than as an HTTP error status.
    """
    try:
        faiss_stats = faissbot.get_database_stats()
        faiss_ready = faissbot.is_initialized()
        queue_status = request_queue.get_queue_status()

        return {
            "api_status": "healthy",
            "llm_status": "healthy",
            "model_loaded": True,
            "device": "cuda" if torch.cuda.is_available() else "cpu",
            "faiss_database_initialized": faiss_ready,
            "faiss_can_perform_searches": faiss_ready,
            "faiss_total_questions": faiss_stats["total_questions"],
            "faiss_model_loaded": faiss_stats["model_loaded"],
            "faiss_index_ready": faiss_stats["index_initialized"],
            "request_queue": queue_status
        }
    except Exception as e:
        # The failure may have come from the request queue itself, so querying
        # it again must not be allowed to turn this error report into a 500.
        try:
            queue_status = request_queue.get_queue_status()
        except Exception as queue_error:
            queue_status = {"error": str(queue_error)}
        return {
            "api_status": "error",
            "error": str(e),
            "llm_status": "healthy",
            "faiss_database_initialized": False,
            "faiss_can_perform_searches": False,
            "request_queue": queue_status
        }

# LLM Inference endpoints
@app.post("/infer/stream-json")
async def infer_stream_json(request: InferenceRequest):
    """
    JSON streaming endpoint for LLM inference.

    Adds the request to the shared queue, waits for the queue worker to hand
    back the inference data, then streams the model output as
    newline-delimited JSON (one ``json.dumps(chunk)`` per line).

    Raises:
        HTTPException(503): the request queue is full.
        HTTPException(500): any other error raised before streaming starts.
            Errors raised while the stream is being consumed occur after this
            function has returned and are not converted here.
    """
    try:
        # Step 1. Add to queue
        try:
            _request_id, future = await request_queue.add_request({
                "question": request.question,
                "context": request.context,
                "vector_info": request.vector_info,
                "database_info": request.database_info
            })
        except QueueFullError:
            raise HTTPException(
                status_code=503,
                detail="The chatbot is VERY busy right now. Please try again in a moment."
            )
        # Step 2. Wait for worker to process and prepare inference data
        inference_data = await future
        processed_request = InferenceRequest(
            question = inference_data.get("question", ""),
            context = inference_data.get("context", ""),
            vector_info = inference_data.get("vector_info", []),
            database_info = inference_data.get("database_info", {})
        )
        # Step 3. Stream model results as newline-delimited JSON
        async def generate_json_stream():
            async for chunk in answer_question_streaming(processed_request):
                yield f"{json.dumps(chunk)}\n"
        return StreamingResponse(
            generate_json_stream(),
            media_type="application/x-ndjson",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Access-Control-Allow-Origin": "*"
            }
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 503 above) must pass through
        # unchanged; the generic handler below would rewrap them as a 500.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Streaming error: {str(e)}") from e

@app.post("/infer/evaluation-stream")
async def roadmap_evaluation_stream(request: RoadmapEvaluation):
    """
    Stream an AI roadmap evaluation based on the student's past and planned courses.

    The request (degree name/code/notes, passed, failed and next courses) is
    added to the shared queue. Once the worker hands it back, the output of
    ``road_map_evaluation`` is streamed as newline-delimited JSON.

    Args:
        request: Degree info plus passed, failed and planned (next) courses.

    Raises:
        HTTPException(503): the request queue is full.
        HTTPException(500): any other error raised before streaming starts.
    """
    try:
        try:
            _request_id, future = await request_queue.add_request({
                    "degree_name": request.degree_name,
                    "degree_code": request.degree_code,
                    "degree_notes": request.degree_notes,
                    "passed_courses": request.passed_courses,
                    "failed_courses": request.failed_courses,
                    "next_courses": request.next_courses,
            })
        except QueueFullError:
            raise HTTPException(
                status_code=503,detail="The chatbot is VERY busy right now. Please try again in a moment."
            )
        # wait for worker to process and prepare inference data
        roadmap_request = await future
        roadmap_model = RoadmapEvaluation(**roadmap_request)
        async def generate_roadmap_json_stream():
            async for chunk in road_map_evaluation(roadmap_model):
                yield f"{json.dumps(chunk)}\n"

        return StreamingResponse(
            generate_roadmap_json_stream(),
            media_type="application/x-ndjson",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Access-Control-Allow-Origin": "*"
            }
        )

    except HTTPException:
        # Keep the deliberate 503 (queue full) instead of rewrapping it as a 500.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Streaming error: {str(e)}") from e
            
@app.post("/infer/evaluation", response_model=RoadmapResponse)
async def roadmap_evaluation(request: RoadmapEvaluation):
    """
    Get an AI roadmap evaluation based on the student's past and planned courses.

    Returns the complete evaluation text in one response instead of streaming.

    Raises:
        HTTPException(503): the request queue is full.

    Any other failure is reported as a normal response with ``success=False``
    and the error description in ``evaluation``.
    """
    try:
        try:
            _request_id, future = await request_queue.add_request({
                "degree_name": request.degree_name,
                "degree_code": request.degree_code,
                "degree_notes": request.degree_notes,
                "passed_courses": request.passed_courses,
                "failed_courses": request.failed_courses,
                "next_courses": request.next_courses,
            })
        except QueueFullError:
            raise HTTPException(
                status_code=503,
                detail="The chatbot is VERY busy right now. Please try again in a moment."
            )

        # Wait for worker to process and prepare inference data
        roadmap_request = await future
        roadmap_model = RoadmapEvaluation(**roadmap_request)
        # Get the complete evaluation text
        evaluation_text = await road_map_evaluation_simple(roadmap_model)
        print(f"Evaluation generated: {evaluation_text}")
        return RoadmapResponse(
            evaluation=evaluation_text,
            success=True
        )

    except HTTPException:
        # Without this, the broad handler below would turn the 503 into a
        # 200 response with success=False, making the explicit raise dead code.
        raise
    except Exception as e:
        traceback.print_exc()
        return RoadmapResponse(
            evaluation=f"Error generating roadmap evaluation: {str(e)}",
            success=False
        )
    

# FAISS Vector Database endpoints
@app.post("/search", response_model=SearchResponse)
async def search(request: SearchRequest):
    """
    Search the FAISS database for questions similar to ``request.question``.

    Relevance filtering (``min_similarity``, ``max_results``) is delegated to
    ``faissbot.search_query``. The matching question IDs are returned as
    ``indices`` (possibly empty), and ``k`` is the number actually returned.

    Raises:
        HTTPException(400): empty question, non-positive k, or FAISS not initialized.
        HTTPException(500): any other error raised during the search.
    """
    try:
        problem = None
        if not request.question.strip():
            problem = "Question cannot be empty"
        elif request.k <= 0:
            problem = "k must be greater than 0"
        elif not faissbot.is_initialized():
            problem = "FAISS database not initialized. Database will be initialized on startup."
        if problem is not None:
            raise HTTPException(status_code=400, detail=problem)

        matches = faissbot.search_query(
            request.question,
            k=request.k,
            min_similarity=request.min_similarity,
            max_results=request.max_results,
        )
        # An empty match list is a valid, non-error result.
        return SearchResponse(indices=matches, query=request.question, k=len(matches))

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during search: {str(e)}")

@app.get("/stats", response_model=StatsResponse)
async def get_database_stats():
    """Return the FAISS database statistics; any failure becomes a 500."""
    try:
        return StatsResponse(**faissbot.get_database_stats())
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error retrieving stats: {str(e)}")

@app.post("/reinitialize")
async def reinitialize_database(request: ReinitializeRequest):
    """
    Reinitialize the FAISS database with the questions in the request body.

    Marked by the original author as not implemented yet (future work).

    NOTE(review): the original note says this endpoint expects a list of
    (id, question) tuples, which is what the other initialisation paths in
    this file pass to faissbot.initialize_database. However,
    ReinitializeRequest.questions is declared as List[str]. Confirm the
    expected format before relying on this endpoint.

    Raises:
        HTTPException(500): initialize_database reported failure, or any
            error was raised while reinitializing or reading the stats.
    """
    try:
        success = faissbot.initialize_database(request.questions)
        if not success:
            raise HTTPException(status_code=500, detail="Failed to reinitialize FAISS database")
        stats = faissbot.get_database_stats()
        return {
            "message": "FAISS database reinitialized successfully",
            "success": True,
            "total_questions": stats["total_questions"]
        }
    except HTTPException:
        # Pass the specific failure through instead of double-wrapping it as
        # "Error reinitializing FAISS database: 500: ...".
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error reinitializing FAISS database: {str(e)}") from e

@app.post("/reinitialize-from-api")
async def reinitialize_database_from_api():
    """
    Reinitialize the FAISS database from fresh data fetched from the external API.

    Triggered manually from this side. The question list is fetched from
    ``{api_url}advisor_chatbot_questions`` and converted into
    (question_id, question) tuples for faissbot.initialize_database.

    Raises:
        HTTPException(503): the external API request failed.
        HTTPException(500): initialize_database reported failure, or any
            other error occurred.
    """
    import asyncio  # local import keeps this block self-contained

    def _fetch_questions():
        """Blocking GET of the question list, returned as (question_id, question) tuples."""
        response = requests.get(f"{api_url}advisor_chatbot_questions", timeout=30)
        response.raise_for_status()
        data = response.json()
        return [(item["question_id"], item["question"]) for item in data]

    try:
        print('Fetching questions from the external database API...')
        # requests is synchronous. Run it in the default thread pool so a slow
        # backend (up to the 30 s timeout) does not stall the event loop, and
        # with it every other in-flight request.
        loop = asyncio.get_running_loop()
        questions = await loop.run_in_executor(None, _fetch_questions)

        success = faissbot.initialize_database(questions)
        if not success:
            raise HTTPException(status_code=500, detail="Failed to reinitialize FAISS database")

        stats = faissbot.get_database_stats()
        print(f"Successfully reinitialized FAISS database with {stats['total_questions']} questions")
        return {
            "message": "FAISS database reinitialized successfully from external API",
            "success": True,
            "total_questions": stats["total_questions"]
        }

    except HTTPException:
        # Do not double-wrap the specific 500 raised above.
        raise
    except requests.RequestException as e:
        raise HTTPException(status_code=503, detail=f"Error fetching data from external API: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error reinitializing FAISS database: {str(e)}") from e

@app.on_event("startup")
async def initialize_database():
    """
    Startup hook: build the FAISS index from the external API, then start the
    request queue worker.

    Both steps are best-effort. Failures are only printed and startup
    continues, so FAISS (and/or the queue worker) may be unavailable afterwards.
    """
    try:
        print('Starting model server and loading model...')
        print('Fetching questions from the external database API for FAISS initialization...')
        reply = requests.get(f"{api_url}advisor_chatbot_questions", timeout=30)
        reply.raise_for_status()
        # (question_id, question) pairs, as expected by faissbot.initialize_database
        questions = [
            (entry["question_id"], entry["question"])
            for entry in reply.json()
        ]
        if faissbot.initialize_database(questions):
            print(f"Success! FAISS database initialized with {len(questions)} questions")
        else:
            print("Warning: Failed to initialize FAISS database")
    except Exception as e:
        print(f"Warning: Error initializing FAISS database: {str(e)}")
        print("FAISS functionality will not be available until manual initialization")

    async def passthrough(data):
        """Queue processor: hands each queued payload back unchanged; the endpoints run inference themselves."""
        return data

    try:
        await request_queue.start_worker(passthrough)
        print("Request queue worker started successfully")
    except Exception as e:
        print(f"Error starting request queue worker: {e}")

@app.on_event("shutdown")
async def shutdown_event():
    """Ask the request queue to shut down; any error is printed rather than raised."""
    try:
        await request_queue.shutdown()
    except Exception as e:
        print(f"Error during shutdown: {e}")
    else:
        print("Request queue shut down gracefully")

if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly.
    import uvicorn

    print("Starting combined model server with FAISS vector database...")
    uvicorn.run(app, port=8001, host="0.0.0.0")